Conversation

jingyu-ml
Contributor

@jingyu-ml jingyu-ml commented Sep 15, 2025

What does this PR do?

Type of change: new feature

Overview:

  1. Added support for exporting dynamic block quantization to ONNX for both MHA and linear layers.
  2. Fixed a minor bug in the diffusion example.

Usage

FP8_SAGE_DEFAULT_CONFIG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
        "*output_quantizer": {"enable": False},
        "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3),"block_sizes": {-2: 32}},
        "*softmax_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
        },
        "default": {"enable": False},
    },
    "algorithm": "max",
}

mtq.quantize(model, FP8_SAGE_DEFAULT_CONFIG, forward_func)

torch.onnx.export(model, ...) # you can follow the diffusion example at https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/diffusers/quantization
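
A minimal end-to-end sketch, assuming placeholder names (dummy_inputs, onnx_path) and an illustrative opset; the diffusers example linked above wraps this flow in its own export helper:

import torch
import modelopt.torch.quantization as mtq

# Quantize with the dynamic block config above, then export to ONNX.
model = mtq.quantize(model, FP8_SAGE_DEFAULT_CONFIG, forward_func)

with torch.inference_mode():
    torch.onnx.export(
        model,
        dummy_inputs,      # placeholder: sample inputs matching your model's forward signature
        onnx_path,         # placeholder: output path, e.g. "model_fp8_block.onnx"
        opset_version=17,  # illustrative; pick the opset your pipeline requires
        do_constant_folding=True,
    )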

Testing

  1. This is an evaluation feature since the TRT kernel isn't ready. No test cases are required at this time; we will add test cases next month.
  2. However, we can continue relying on the existing FP8 per-tensor test cases as usual.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

  • New Features

    • Optional FP8 blockwise quantization for tensors and attention (Q/K/V), configurable at runtime and preserved in ONNX export.
    • New FP8 SAGE preset enabling dynamic blockwise quantization for attention matmuls.
  • Improvements

    • MHA quantization driven by runtime configuration rather than a class default.
    • Per-dimension block-size support propagated end-to-end through attention and export paths.
    • Cleanup to avoid lingering temporary quantization state after forward calls.

@jingyu-ml jingyu-ml self-assigned this Sep 15, 2025
@jingyu-ml jingyu-ml requested review from a team as code owners September 15, 2025 23:06

coderabbitai bot commented Sep 15, 2025

Walkthrough

Adds optional FP8 blockwise quantization end-to-end: new FP8_SAGE_DEFAULT_CONFIG, runtime-driven quantize_mha flag, propagation of per-tensor vs dynamic block shapes through diffusers attention and ONNX symbolic/export paths, TensorQuantizer and ScaledE4M3 support for block sizes, and ONNX helpers for blockwise FP8 quantize/dequantize.

Changes

Cohort / File(s) Summary
Diffusers config & usage
examples/diffusers/quantization/config.py, examples/diffusers/quantization/quantize.py
Adds FP8_SAGE_DEFAULT_CONFIG (dynamic QKV block settings). quantize.py now uses quant_config.quantize_mha at runtime.
ONNX FP8 export (blockwise support)
modelopt/torch/quantization/export_onnx.py
Adds _fp8_block_quantize and _fp8_block_dequantize. export_fp8 accepts `amax: float | None` and an optional block_sizes list, routing to blockwise or per-tensor quantize/dequantize.
Tensor quantizer internals
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Adds _get_block_sizes_list(self, shape); stores _original_input_shape during forward and clears it after; FP8 fake-quant path derives per-dimension block_sizes_list and passes it into scaled_e4m3.
Diffusers FP8 SDPA path
modelopt/torch/quantization/plugins/diffusers.py
Detects dynamic vs non-dynamic Q/K/V quantizers; computes q/k/v block sizes and passes them through; updates _QuantAttention, FP8SDPA.forward, and FP8SDPA.symbolic signatures to accept and propagate q_block_shape, k_block_shape, v_block_shape.
Tensor quant function + symbolic
modelopt/torch/quantization/tensor_quant.py
ScaledE4M3Function parse_args/symbolic updated to include block_sizes; forward signature adds block_sizes; ONNX export now passes block_sizes to export_fp8.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant QC as QuantConfig
  participant Attn as _QuantAttention/FP8SDPA
  participant Sym as FP8SDPA.symbolic
  participant Export as export_fp8_mha
  participant BlockHelpers as _fp8_block_quant/_fp8_block_dequant

  User->>QC: load FP8_SAGE_DEFAULT_CONFIG
  User->>Attn: forward(query, key, value, ...)
  Attn->>Attn: detect dynamic vs non-dynamic Q/K/V
  Attn->>Attn: compute q/k/v block_shapes via _get_block_sizes_list()
  Attn->>Sym: symbolic(..., q_block_shape, k_block_shape, v_block_shape)
  Sym->>Export: export_fp8_mha(..., q_block_shape, k_block_shape, v_block_shape)
  alt block shapes present
    Export->>BlockHelpers: _fp8_block_quantize(Q/K/V, block_shape)
    BlockHelpers-->>Export: quantized uint8 + scales
    Export->>BlockHelpers: _fp8_block_dequantize(..., block_shape)
    BlockHelpers-->>Export: dequantized tensors
  else no block shapes
    Export->>Export: per-tensor FP8 quantize/dequantize
  end
  Export-->>User: ONNX graph with FP8 (blockwise or per-tensor)
sequenceDiagram
  autonumber
  participant TensorQ as TensorQuantizer
  participant SE4 as ScaledE4M3Function
  participant Export as export_fp8
  participant BlockHelpers as _fp8_block_quant/_fp8_block_dequant

  TensorQ->>TensorQ: _get_block_sizes_list(_original_input_shape)
  TensorQ->>SE4: forward(x, scale, amax, block_sizes_list, ...)
  alt ONNX export path
    SE4->>Export: export_fp8(..., amax=None|float, block_sizes)
    opt block_sizes provided
      Export->>BlockHelpers: _fp8_block_quantize/_fp8_block_dequantize
    end
  else eager fake-quant
    SE4->>SE4: apply fake-quant with block sizes
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60–90 minutes

Poem

I nibble bytes in tidy blocks and map each tiny shape,
Q, K, V hop in tiles and glide—no scale shall escape.
I dig through graphs and stitch the paths where blockwise carrots lie,
ONNX lanes now hum with crumbs; I twitch my whiskers, sigh.
Hop—export done, a carrot crunch beneath the sky. 🥕🐇

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title "FP8 Block quantize onnx export support" succinctly and accurately captures the primary change in the PR—adding blockwise FP8 quantization support for ONNX export (including MHA and linear layers) as shown in the summary and code changes; it is short, focused, and relevant for a reviewer scanning history.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 35f3da2 and 25be640.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/tensor_quant.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/tensor_quant.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).


@jingyu-ml jingyu-ml marked this pull request as draft September 15, 2025 23:06

copy-pr-bot bot commented Sep 15, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@jingyu-ml jingyu-ml force-pushed the jingyux/block-quant-onnx branch from e4d1775 to 071f167 Compare September 15, 2025 23:14
@jingyu-ml jingyu-ml requested a review from kaix-nv September 15, 2025 23:15
@jingyu-ml jingyu-ml marked this pull request as ready for review September 15, 2025 23:15
Signed-off-by: Jingyu Xin <[email protected]>

codecov bot commented Sep 15, 2025

Codecov Report

❌ Patch coverage is 17.07317% with 34 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.71%. Comparing base (bbb2304) to head (9e88a34).

Files with missing lines Patch % Lines
modelopt/torch/quantization/export_onnx.py 8.33% 22 Missing ⚠️
.../torch/quantization/nn/modules/tensor_quantizer.py 28.57% 10 Missing ⚠️
modelopt/torch/quantization/tensor_quant.py 33.33% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #324      +/-   ##
==========================================
- Coverage   73.82%   73.71%   -0.12%     
==========================================
  Files         172      172              
  Lines       17438    17471      +33     
==========================================
+ Hits        12874    12879       +5     
- Misses       4564     4592      +28     

☔ View full report in Codecov by Sentry.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (11)
examples/diffusers/quantization/config.py (1)

39-39: Fix spacing inconsistency in configuration.

The configuration has inconsistent spacing after commas. Line 39 has a missing space after the comma between (4, 3) and "block_sizes".

-        "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3),"block_sizes": {-2: 32}},
+        "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}},
modelopt/torch/quantization/export_onnx.py (2)

237-263: Consider adding validation for block_sizes parameter.

The new _fp8_block_quantize function should validate the structure and values of the block_sizes parameter to prevent runtime errors during ONNX export.

Add validation at the beginning of the function:

 def _fp8_block_quantize(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
     trt_high_precision_dtype: str,
     block_sizes: list,
 ):
     """Helper Function for Quantization."""
+    if not isinstance(block_sizes, list) or not block_sizes:
+        raise ValueError(f"block_sizes must be a non-empty list, got {block_sizes}")
+    if not all(isinstance(b, int) and b > 0 for b in block_sizes):
+        raise ValueError(f"All block sizes must be positive integers, got {block_sizes}")
+        
     output_shape = sym_help._get_tensor_sizes(inputs)

534-535: Fix typo in comment.

-        # We cannot do block quant for the softmax's output 
+        # We cannot do block quant for the softmax's output
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

653-653: Remove trailing whitespace from blank lines.

These blank lines contain unnecessary whitespace which violates Python style guidelines.

-        
+
         Args:
             shape: The tensor shape to use for conversion (can be tuple or torch.Size)
-            
+
         Returns:
             List of block sizes for each dimension, or None if block_sizes is None
-            
+
         Example:

Also applies to: 656-656, 659-659


961-962: Consider thread-safety for the _original_input_shape attribute.

Setting and deleting _original_input_shape as a temporary attribute could cause issues in multi-threaded scenarios where the same quantizer is used by multiple threads simultaneously.

Consider using a context manager or local variable approach instead:

-            setattr(self, "_original_input_shape", inputs.shape)
-            inputs = self._process_for_blockquant(inputs)
+            original_shape = inputs.shape
+            inputs = self._process_for_blockquant(inputs)
+            # Pass original_shape to methods that need it

Alternatively, consider storing it in a thread-local storage if multi-threading support is required.

modelopt/torch/quantization/plugins/diffusers.py (6)

117-124: Fix assertion logic for mixed dynamic/non-dynamic quantizers

The current implementation requires all QKV quantizers to be either dynamic or non-dynamic together. However, the logic flow is problematic - if they're all non-dynamic, scales are computed, but if any is dynamic, it asserts all must be dynamic. This creates a rigid constraint that may not be necessary for all use cases.

Consider refactoring to handle mixed cases more gracefully:

-    if not self.q_bmm_quantizer._dynamic and not self.k_bmm_quantizer._dynamic and not self.v_bmm_quantizer._dynamic:
-        q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
-        k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
-        v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
-    else:
-        assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type"
-        q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None
+    # Compute scales for non-dynamic quantizers, set None for dynamic ones
+    q_quantized_scale = None if self.q_bmm_quantizer._dynamic else self.q_bmm_quantizer._get_amax(query)
+    k_quantized_scale = None if self.k_bmm_quantizer._dynamic else self.k_bmm_quantizer._get_amax(key)
+    v_quantized_scale = None if self.v_bmm_quantizer._dynamic else self.v_bmm_quantizer._get_amax(value)
+    
+    # Optionally validate consistency if needed
+    dynamic_states = [self.q_bmm_quantizer._dynamic, self.k_bmm_quantizer._dynamic, self.v_bmm_quantizer._dynamic]
+    if len(set(dynamic_states)) > 1:
+        # Log warning or handle mixed dynamic states if necessary
+        pass

122-122: Fix line length violation

Line 122 exceeds the 120 character limit (149 characters).

-        assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type"
+        assert (self.q_bmm_quantizer._dynamic and 
+                self.k_bmm_quantizer._dynamic and 
+                self.v_bmm_quantizer._dynamic), "QKV QDQS must be in the same type"

144-146: Remove trailing whitespace

Line 145 has trailing whitespace after the comma.

             q_block_sizes,
-            k_block_sizes, 
+            k_block_sizes,
             v_block_sizes,

231-233: Inconsistent default values for scale parameters

The scale parameters have inconsistent default values in the symbolic method signature (float | None = 1.0) which doesn't match the forward method where they default to None.

-        q_quantized_scale: float | None = 1.0,
-        k_quantized_scale: float | None = 1.0,
-        v_quantized_scale: float | None = 1.0,
+        q_quantized_scale: float | None = None,
+        k_quantized_scale: float | None = None,
+        v_quantized_scale: float | None = None,

200-202: Consider using TypeAlias for block shape type consistency

The block shape parameters use list | None type annotations repeatedly. Consider defining a type alias for better maintainability and consistency.

Add at the top of the file after imports:

from typing import TypeAlias

BlockShape: TypeAlias = list[int] | None

Then update the signatures:

-        q_block_shape: list | None = None,
-        k_block_shape: list | None = None,
-        v_block_shape: list | None = None,
+        q_block_shape: BlockShape = None,
+        k_block_shape: BlockShape = None,
+        v_block_shape: BlockShape = None,

Also applies to: 236-238


126-128: Add validation for block sizes consistency

The code retrieves block sizes from quantizers but doesn't validate that they're compatible with the actual tensor shapes or the quantization configuration.

Consider adding validation to ensure block sizes are appropriate:

     # Get block sizes lists for each quantizer if needed
     q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)
     k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape)
     v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape)
+    
+    # Validate block sizes if dynamic quantization is enabled
+    if self.q_bmm_quantizer._dynamic and q_block_sizes:
+        for dim, block_size in enumerate(q_block_sizes):
+            if block_size > 1 and query.shape[dim] % block_size != 0:
+                raise ValueError(f"Query dimension {dim} (size {query.shape[dim]}) is not divisible by block size {block_size}")
+    # Similar validation for k and v can be added
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76e8ce2 and 071f167.

📒 Files selected for processing (6)
  • examples/diffusers/quantization/config.py (1 hunks)
  • examples/diffusers/quantization/quantize.py (1 hunks)
  • modelopt/torch/quantization/export_onnx.py (6 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
  • modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
  • modelopt/torch/quantization/tensor_quant.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • _get_amax (540-549)
  • _get_block_sizes_list (651-672)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
  • trt_high_precision_dtype (407-409)
  • trt_high_precision_dtype (412-413)
  • block_sizes (289-291)
  • block_sizes (294-296)
  • amax (233-238)
  • amax (241-252)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (7)
  • amax (233-238)
  • amax (241-252)
  • block_sizes (289-291)
  • block_sizes (294-296)
  • trt_high_precision_dtype (407-409)
  • trt_high_precision_dtype (412-413)
  • forward (902-1004)
modelopt/torch/quantization/export_onnx.py (1)
  • export_fp8 (321-342)
🪛 GitHub Actions: Code Quality
modelopt/torch/quantization/plugins/diffusers.py

[error] 122-122: E501 Line too long (149 > 120).


[error] 134-136: mypy: Item 'str' of 'str | float | None' has no attribute 'shape' [union-attr] (lines 134-136).

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

[warning] 653-653: W293 Blank line contains whitespace.


[warning] 656-656: W293 Blank line contains whitespace.


[warning] 659-659: W293 Blank line contains whitespace.


[warning] 959-959: RUF003 Comment contains ambiguous RIGHT SINGLE QUOTATION MARK ’. Did you mean `` (GRAVE ACCENT)?

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
🔇 Additional comments (6)
examples/diffusers/quantization/quantize.py (1)

942-942: LGTM! Runtime configuration for MHA quantization is properly handled.

The change from QuantizationConfig.quantize_mha to quant_config.quantize_mha correctly uses the runtime configuration for MHA quantization. This aligns with the overall pattern of making MHA quantization configurable.

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

651-673: LGTM! Well-designed helper method for block size conversion.

The _get_block_sizes_list method is well-implemented with proper handling of both positive and negative dimension indices. The documentation with examples clearly explains its purpose.
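
For readers skimming the thread, a standalone sketch of the conversion being praised here (hedged: the free function below mirrors the method's logic from the reviewed diff; the real implementation is a TensorQuantizer method that reads self.block_sizes):

def get_block_sizes_list(block_sizes: dict | None, shape) -> list | None:
    """Map a {dim: block_size} dict (positive or negative dims) onto a per-dimension list."""
    if block_sizes is None:
        return None
    block_sizes_list = []
    for dim in range(len(shape)):
        # Accept either positive or negative indices for the same axis.
        dim_negative = dim - len(shape)
        block_size = block_sizes.get(dim, None) or block_sizes.get(dim_negative, None)
        block_sizes_list.append(block_size if block_size is not None else 1)
    return block_sizes_list

# e.g. the {-2: 32} setting from FP8_SAGE_DEFAULT_CONFIG on a rank-4 activation:
assert get_block_sizes_list({-2: 32}, (2, 24, 4096, 64)) == [1, 1, 32, 1]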

modelopt/torch/quantization/export_onnx.py (1)

298-298: Incorrect — default block_sizes is not exercised by current callers

Call sites found in modelopt/torch/quantization/export_onnx.py pass block_sizes into _fp8_block_dequantize, so the hardcoded [1,1,128,1] default isn't used; no change required.

Likely an incorrect or invalid review comment.

modelopt/torch/quantization/tensor_quant.py (1)

415-415: No action required — parse_args('i' → 'is') is backward-compatible.

'is' accepts both integer and integer-list, so existing callers that pass integers remain valid (repo tests call scaled_e4m3(..., 4, 3)).

modelopt/torch/quantization/plugins/diffusers.py (2)

87-148: Good implementation of block quantization propagation!

The changes effectively extend the FP8 quantization path to support dynamic block quantization:

  • Proper handling of dynamic vs non-dynamic quantizers
  • Clean propagation of block shape parameters through the call stack
  • Maintains backward compatibility with existing code

221-221: parse_args: 'is' is correct for int-list (or None)

Matches existing usage in modelopt/torch/quantization/tensor_quant.py — the three "is" entries correctly map the three block-shape parameters to int[] | None.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

959-963: Ensure _original_input_shape is always cleaned up (wrap in try/finally)

State can leak if an exception occurs between setting and deleting _original_input_shape. Wrap the forward section after setup in try/finally.

-        if (
+        cleanup_original_shape = False
+        if (
             self.block_sizes is not None
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
             # Reshape is required if the logic isnt handled in the simulation kernel
             self._setup_for_blockquant(inputs)
             setattr(self, "_original_input_shape", inputs.shape)
+            cleanup_original_shape = True
             inputs = self._process_for_blockquant(inputs)
 
-        outputs = inputs
+        try:
+            outputs = inputs
             ...
-        if hasattr(self, "_original_input_shape"):
-            delattr(self, "_original_input_shape")
+        finally:
+            if cleanup_original_shape and hasattr(self, "_original_input_shape"):
+                delattr(self, "_original_input_shape")

Also applies to: 1002-1003

🧹 Nitpick comments (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

681-699: Comment is misleading within dynamic-only branch

The comment says “Double scale Block quantization, including dynamic and static block quantization” but this branch executes only when type == "dynamic". Tighten the comment to avoid confusion.

-            # Double scale Block quantization, including dynamic and static block quantization
+            # Dynamic double-scale block quantization path
modelopt/torch/quantization/plugins/diffusers.py (2)

117-132: Align QKV mode and error message; avoid computing scales in export path

  • Message “QKV QDQS must be in the same type” is unclear. Make it explicit: “Q, K, and V quantizers must all be dynamic or all be static.”
  • Skip _get_amax when exporting; it’s unused at runtime and can be None for dynamic. Guard by torch.onnx.is_in_onnx_export().
-    if (
+    if (
         not self.q_bmm_quantizer._dynamic
         and not self.k_bmm_quantizer._dynamic
         and not self.v_bmm_quantizer._dynamic
     ):
-        q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
-        k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
-        v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+        if not torch.onnx.is_in_onnx_export():
+            q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
+            k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
+            v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+        else:
+            q_quantized_scale = k_quantized_scale = v_quantized_scale = None
     else:
         assert (
             self.q_bmm_quantizer._dynamic
             and self.k_bmm_quantizer._dynamic
             and self.v_bmm_quantizer._dynamic
-        ), "QKV QDQS must be in the same type"
+        ), "Q, K, and V quantizers must all be dynamic or all be static."
         q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None

133-137: Using a private helper across modules; consider promoting to public API

Calling _get_block_sizes_list from another module couples to a private method. Expose it as a public helper (e.g., get_block_sizes_list or a util function) to avoid brittle dependencies.

Would you like me to extract a small utility, e.g., modelopt/torch/quantization/utils/block_sizes.py:get_block_sizes_list(shape, block_sizes)?

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 071f167 and 831c32d.

📒 Files selected for processing (4)
  • examples/diffusers/quantization/config.py (1 hunks)
  • modelopt/torch/quantization/export_onnx.py (6 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
  • modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/diffusers/quantization/config.py
  • modelopt/torch/quantization/export_onnx.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • _get_amax (540-549)
  • _get_block_sizes_list (651-672)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
🔇 Additional comments (2)
modelopt/torch/quantization/plugins/diffusers.py (2)

151-155: Default disable flag may be surprising

You pass True when _disable_fp8_mha is absent, which disables FP8 MHA by default. Confirm this is intended for evaluation builds, or flip default to False.

Would you like a config flag gate to avoid silent disablement in production?
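
A minimal sketch of the fallback being questioned (hedged: the helper name is illustrative and the getattr pattern is inferred from this comment, not copied from the diff):

def resolve_disable_fp8_mha(module) -> bool:
    # True (FP8 MHA disabled) when _disable_fp8_mha is absent -- the default flagged above.
    # Flipping the fallback to False would keep FP8 MHA enabled unless explicitly disabled.
    return getattr(module, "_disable_fp8_mha", True)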


251-269: Verified — export_fp8_mha signature matches callsite

Definition in modelopt/torch/quantization/export_onnx.py (def export_fp8_mha at ~line 420) includes q_block_shape, k_block_shape, v_block_shape; the call in modelopt/torch/quantization/plugins/diffusers.py passes them — no mismatch found.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

703-716: Fix scaled_e4m3 call-site argument order

scaled_e4m3 now expects block_sizes before bias; update all callers to the new signature.

  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py:707 (new usage)
  • modelopt/torch/quantization/calib/histogram.py:311
  • tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py:55
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py:148, 158, 166, 173, 185, 187, 202

Use: scaled_e4m3(inputs, amax, block_sizes, bias, E, M, ...). If no block_sizes, pass None as the third argument and move bias to the fourth.

♻️ Duplicate comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

959-963: Guarantee cleanup of _original_input_shape on exceptions.

Deletion is unconditional but not exception-safe; wrap the quantization region in try/finally. Also fix the typo “isnt” → “isn't”.

-        if (
+        cleanup_original_input_shape = False
+        if (
             self.block_sizes is not None
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
-            # Reshape is required if the logic isnt handled in the simulation kernel
+            # Reshape is required if the logic isn't handled in the simulation kernel
             self._setup_for_blockquant(inputs)
             setattr(self, "_original_input_shape", inputs.shape)
+            cleanup_original_input_shape = True
             inputs = self._process_for_blockquant(inputs)
 
-        outputs = inputs
+        try:
+            outputs = inputs
@@
-        if (
+            if (
                 self.block_sizes is not None
                 and self.block_sizes.get("type", None) != "dynamic"
                 and self._fake_quant
-        ):
-            outputs = self._reset_to_original_shape(outputs)
-
-        if hasattr(self, "_original_input_shape"):
-            delattr(self, "_original_input_shape")
-        return outputs
+            ):
+                outputs = self._reset_to_original_shape(outputs)
+            return outputs
+        finally:
+            if cleanup_original_input_shape and hasattr(self, "_original_input_shape"):
+                delattr(self, "_original_input_shape")

Also applies to: 1002-1003

🧹 Nitpick comments (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

651-673: Type-safety and input validation for _get_block_sizes_list.

Add explicit typing, validate keys, and guard length mismatches to avoid silently passing malformed shapes to downstream ONNX ops.

-def _get_block_sizes_list(self, shape):
+from typing import Sequence
+
+def _get_block_sizes_list(self, shape: Sequence[int] | torch.Size) -> list[int] | None:
@@
-        block_sizes_list = []
-        for dim in range(len(shape)):
+        # Only allow integer axes plus known metadata keys.
+        valid_meta = {"type", "scale_bits", "scale_block_sizes"}
+        assert all(
+            isinstance(k, int) or k in valid_meta for k in self.block_sizes.keys()
+        ), f"Invalid block_sizes keys: {list(self.block_sizes.keys())}"
+
+        rank = len(shape)
+        block_sizes_list: list[int] = []
+        for dim in range(rank):
             # Check both positive and negative dimension indices
-            dim_negative = dim - len(shape)
+            dim_negative = dim - rank
             block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
             block_sizes_list.append(block_size if block_size is not None else 1)
         return block_sizes_list
modelopt/torch/quantization/export_onnx.py (2)

238-265: Validate block_shape rank and surface clearer errors.

Add a rank check before emitting TRT_DynamicQuantize; mis-sized block_shapes currently fall through to TRT with cryptic errors.

 def _fp8_block_quantize(
@@
-    input_type = inputs.type().scalarType()
+    input_type = inputs.type().scalarType()
+    rank = symbolic_helper._get_tensor_rank(inputs)
+    assert rank is not None, "Input rank must be known at export time."
+    assert len(block_sizes) == rank, (
+        f"block_shape length ({len(block_sizes)}) must match input rank ({rank})."
+    )
@@
     quantized_output, scales_output = g.op(
         "trt::TRT_DynamicQuantize",
         inputs,
         block_shape_i=block_sizes,

503-509: Block-shape consistency in FP8 MHA path.

Validate q/k/v block shapes match input ranks; also ensure softmax path never receives a block shape.

-        query_scaled = export_fp8(
-            g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape
-        )
+        assert (q_block_shape is None) or (
+            len(q_block_shape) == symbolic_helper._get_tensor_rank(query_scaled)
+        ), "q_block_shape rank mismatch."
+        query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape)
@@
-        key_transposed_scaled = export_fp8(
-            g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape
-        )
+        assert (k_block_shape is None) or (
+            len(k_block_shape) == symbolic_helper._get_tensor_rank(key_transposed_scaled)
+        ), "k_block_shape rank mismatch."
+        key_transposed_scaled = export_fp8(g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape)
@@
-        # We cannot do block quant for the softmax's output
-        attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None)
+        # We cannot do block quant for the softmax's output
+        attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None)
@@
-        value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)
+        assert (v_block_shape is None) or (
+            len(v_block_shape) == symbolic_helper._get_tensor_rank(value)
+        ), "v_block_shape rank mismatch."
+        value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)

Also applies to: 535-549

modelopt/torch/quantization/plugins/diffusers.py (2)

117-132: Q/K/V quantization mode must match: improve error and skip redundant work.

The assertion is good. Minor: clarify message and avoid computing per-tensor amax if any quantizer is dynamic.

-        ), "QKV QDQS must be in the same type"
+        ), "Q/K/V quantization modes must match: either all dynamic or all static."

133-137: Guard block size list creation when block_sizes is None.

If a quantizer has no block_sizes, _get_block_sizes_list returns None; that's fine. Add a quick comment to make the intent explicit and future-proof.

-    q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
+    # Returns None for per-tensor paths; ONNX export handles that by taking the non-block path.
+    q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 071f167 and 831c32d.

📒 Files selected for processing (4)
  • examples/diffusers/quantization/config.py (1 hunks)
  • modelopt/torch/quantization/export_onnx.py (6 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
  • modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/diffusers/quantization/config.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
  • trt_high_precision_dtype (407-409)
  • trt_high_precision_dtype (412-413)
  • block_sizes (289-291)
  • block_sizes (294-296)
  • amax (233-238)
  • amax (241-252)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • _get_amax (540-549)
  • _get_block_sizes_list (651-672)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
🔇 Additional comments (1)
modelopt/torch/quantization/plugins/diffusers.py (1)

229-231: Keep 't' for q/k/v scales — export extracts consts

export_onnx.py already extracts constant floats (uses sym_help._get_const / _maybe_get_const for scale/amax), so the current parse_args ("..., 't','t','t', ...") is fine; only change those three to 'f' if the export_fp8 const-extraction fix is removed. Location: modelopt/torch/quantization/plugins/diffusers.py (symbolic decorator around lines 229–231; same check applies to lines ~241–249).

Signed-off-by: Jingyu Xin <[email protected]>
Contributor

@Edwardf0t1 Edwardf0t1 left a comment

Please also share a sample ONNX export command for a supported model in the description.

):
# Tensor reshaping is required for static block quantization
# Tensor shapes are handled separately by the quantization kernels for dynamic block quantization
# Reshape is required if the logic isnt handled in the simulation kernel
Contributor

nit: isn't

How do we check when the reshape is or isn't handled in the simulation kernel?

Contributor Author

kernels are here: https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/modelopt/torch/quantization/src

The kernels only support the MX format and reshaped NVFP4; other formats require a Torch reshape. I think the previous comment made a false statement, so I just added one more comment.
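
A small, hedged restatement of that gate as a standalone predicate (the condition mirrors the reviewed forward() diff; the MX/NVFP4 claim comes from the comment above):

def needs_torch_reshape(block_sizes: dict | None, fake_quant: bool) -> bool:
    """True when the Torch-side reshape path is taken in forward().

    Static (non-dynamic) block quantization under fake-quant is reshaped in Torch;
    dynamic block quantization leaves shape handling to the simulation kernels,
    which currently only support MX formats and reshaped NVFP4 natively.
    """
    return (
        block_sizes is not None
        and block_sizes.get("type", None) != "dynamic"
        and fake_quant
    )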

@staticmethod
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
@symbolic_helper.parse_args(
"v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
Contributor

Why was this pattern changed?

Contributor Author

You can check parse_args:
https://github.com/pytorch/pytorch/blob/main/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py#L301

Each value represents the input type, which helps Torch trace the graph more effectively.
Since the block shape is a list of integers, we add three "is".

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

649-671: Return None for “metadata‑only” block_sizes and add typing.

Avoid forcing block mode with [1,…,1] when there are no per‑dim entries; this can route ONNX path incorrectly.

Apply:

-    def _get_block_sizes_list(self, shape):
+    def _get_block_sizes_list(self, shape) -> list[int] | None:
         """Convert block_sizes dict to list format based on tensor shape.
@@
-        if self.block_sizes is None:
+        if self.block_sizes is None:
             return None
-
-        block_sizes_list = []
+        # If there are no integer dimension entries with a meaningful block size, treat as no block quant.
+        has_dim_sizes = any(
+            isinstance(k, int) and (v is not None and v != 1) for k, v in self.block_sizes.items()
+        )
+        if not has_dim_sizes:
+            return None
+
+        block_sizes_list: list[int] = []
         for dim in range(len(shape)):
             # Check both positive and negative dimension indices
             dim_negative = dim - len(shape)
             block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
             block_sizes_list.append(block_size if block_size is not None else 1)
         return block_sizes_list

1005-1007: Centralize deletion of _original_input_shape in a finally block.

Move this into a single finally at the end of forward so it runs regardless of success/failure.

Apply:

-        if hasattr(self, "_original_input_shape"):
-            delattr(self, "_original_input_shape")
-        return outputs
+        try:
+            return outputs
+        finally:
+            if hasattr(self, "_original_input_shape"):
+                delattr(self, "_original_input_shape")
modelopt/torch/quantization/export_onnx.py (3)

294-321: Remove brittle default [1,1,128,1] and validate rank in _fp8_block_dequantize.

Defaulting silently is dangerous and shape-dependent. Require explicit block_sizes and assert correctness. (Same concern raised earlier.)

Apply:

-def _fp8_block_dequantize(
+def _fp8_block_dequantize(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
     scales: torch.Value,
     trt_high_precision_dtype: str,
     otype: str | None = None,
-    block_sizes: list = [1, 1, 128, 1],
+    block_sizes: list,
 ):
     """Helper Function for Dequantization."""
     output_shape = sym_help._get_tensor_sizes(inputs)
+    # Validate block shape
+    rank = sym_help._get_tensor_rank(inputs)
+    assert rank is not None, "Input rank must be known at export time."
+    assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, (
+        f"block_shape length ({len(block_sizes)}) must match input rank ({rank})."
+    )
+    assert all(isinstance(b, int) and b > 0 for b in block_sizes), (
+        "All entries in block_shape must be positive integers."
+    )
+    if otype is None:
+        otype = inputs.type().scalarType()

323-345: Handle non-Python amax safely and validate block_shapes before block Q/DQ path.

float(amax) will break when amax is a graph Value/0‑dim tensor; also assert block_shapes align with input rank before calling block ops. (Echoing prior comment.)

Apply:

 def export_fp8(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
-    amax: float | None,
+    amax: float | None,
     trt_high_precision_dtype: str | None,
     block_sizes: list | None,
 ):
     """Export quantized model to FP8 ONNX."""
-    scale = 1.0 if amax is None else 448.0 / float(amax)
+    if amax is None:
+        scale = 1.0
+    else:
+        amax_const = sym_help._get_const(amax, "f", "amax")
+        # If not a constant at export time, fall back to neutral scale to avoid exporter errors.
+        scale = 1.0 if (amax_const is None or amax_const == 0) else 448.0 / float(amax_const)
@@
-    if not block_sizes:
+    if not block_sizes:
         q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
         return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
     else:
+        # Validate block shape early
+        rank = sym_help._get_tensor_rank(inputs)
+        assert rank is not None, "Input rank must be known at export time."
+        assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, (
+            f"block_shape length ({len(block_sizes)}) must match input rank ({rank})."
+        )
+        assert all(isinstance(b, int) and b > 0 for b in block_sizes), (
+            "All entries in block_shape must be positive integers."
+        )
         q_tensor, scales_output = _fp8_block_quantize(
             g, inputs, trt_high_precision_dtype, block_sizes
         )
         return _fp8_block_dequantize(
             g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes
         )

238-265: Validate block_shape against input rank and values in _fp8_block_quantize.

Guard against mismatched ranks and non-positive entries to avoid invalid custom op attributes at export time.

Apply:

 def _fp8_block_quantize(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
     trt_high_precision_dtype: str,
     block_sizes: list,
 ):
     """Helper Function for Quantization."""
     output_shape = sym_help._get_tensor_sizes(inputs)
 
+    # Validate block shape
+    rank = sym_help._get_tensor_rank(inputs)
+    assert rank is not None, "Input rank must be known at export time."
+    assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, (
+        f"block_shape length ({len(block_sizes)}) must match input rank ({rank})."
+    )
+    assert all(isinstance(b, int) and b > 0 for b in block_sizes), (
+        "All entries in block_shape must be positive integers."
+    )
🧹 Nitpick comments (2)
modelopt/torch/quantization/export_onnx.py (1)

512-518: Pre-validate q/k block_shapes vs tensor ranks to fail fast.

Catch mismatches early instead of deep inside TRT ops.

Apply:

-        query_scaled = export_fp8(
+        # Sanity-check block shapes
+        for name, t, bs in (("q", query_scaled, q_block_shape), ("k", key_transposed_scaled, k_block_shape)):
+            if bs is not None:
+                r = sym_help._get_tensor_rank(t)
+                assert r is not None and len(bs) == r, f"{name}_block_shape must match rank ({r})."
+        query_scaled = export_fp8(
             g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape
         )
@@
-        key_transposed_scaled = export_fp8(
+        key_transposed_scaled = export_fp8(
             g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape
         )
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

962-965: Guarantee cleanup of _original_input_shape via try/finally; fix typo.

Ensure attribute is deleted even on exceptions; also fix “isnt” -> “isn't”.

Apply:

-            # Reshape is required if the logic isnt handled in the simulation kernel
-            self._setup_for_blockquant(inputs)
-            setattr(self, "_original_input_shape", inputs.shape)
-            inputs = self._process_for_blockquant(inputs)
+            # Reshape is required if the logic isn't handled in the simulation kernel
+            cleanup_original_shape = False
+            try:
+                self._setup_for_blockquant(inputs)
+                setattr(self, "_original_input_shape", inputs.shape)
+                cleanup_original_shape = True
+                inputs = self._process_for_blockquant(inputs)
+            except Exception:
+                # Make sure we don't leak transient attributes on failure
+                if cleanup_original_shape and hasattr(self, "_original_input_shape"):
+                    delattr(self, "_original_input_shape")
+                raise
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d2c6e0f and 0af26b2.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/export_onnx.py (6 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
  • trt_high_precision_dtype (405-407)
  • trt_high_precision_dtype (410-411)
  • block_sizes (287-289)
  • block_sizes (292-294)
  • amax (231-236)
  • amax (239-250)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/torch/quantization/export_onnx.py (2)

545-549: Softmax stays non-block — good call.

Keeping softmax output per-tensor avoids mismatched block semantics across BMM2.


558-559: Block‑quantizing V path looks correct.

The cast order aligns with the earlier Q/K path.

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

701-714: Pass block_sizes to scaled_e4m3 by keyword to avoid positional drift.

rg found only call sites (tests and histogram.py); I couldn't locate the scaled_e4m3 definition—confirm it accepts a block_sizes kwarg. If it does, apply:

         outputs = scaled_e4m3(
             inputs,
             amax,
             self._get_bias(inputs),
             E,
             M,
             self._trt_high_precision_dtype,
             self._pass_through_bwd,
-            block_sizes_list,
+            block_sizes=block_sizes_list,
         )

Signed-off-by: Jingyu Xin <[email protected]>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

649-671: Don’t force “block mode” when block_sizes has only metadata; return None.

Current logic returns [1, ...] even if block_sizes has no per‑dim sizes, unintentionally enabling block handling downstream. Return None when there are no integer keys with v not in {None, 1}. Add typing while here.

-    def _get_block_sizes_list(self, shape):
+    def _get_block_sizes_list(self, shape) -> list[int] | None:
@@
-        if self.block_sizes is None:
-            return None
-
-        block_sizes_list = []
+        if self.block_sizes is None:
+            return None
+        # If there are no per-dimension entries with a real block size, treat as no block quant.
+        has_dim_sizes = any(
+            isinstance(k, int) and (v is not None and v != 1)
+            for k, v in self.block_sizes.items()
+        )
+        if not has_dim_sizes:
+            return None
+
+        block_sizes_list: list[int] = []
         for dim in range(len(shape)):
             # Check both positive and negative dimension indices
             dim_negative = dim - len(shape)
             block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
             block_sizes_list.append(block_size if block_size is not None else 1)
         return block_sizes_list

1006-1007: Guarantee deletion of _original_input_shape with try/finally (prev feedback).

Move deletion into a finally block guarded by the local flag so it always runs.

-        if hasattr(self, "_original_input_shape"):
-            delattr(self, "_original_input_shape")
-        return outputs
+        try:
+            return outputs
+        finally:
+            if cleanup_original_input_shape and hasattr(self, "_original_input_shape"):
+                delattr(self, "_original_input_shape")
🧹 Nitpick comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

962-966: Make _original_input_shape cleanup exception‑safe: set flag when set, clean in finally.

Set a local flag when creating the attribute so the finally block can reliably clean it even if an exception occurs later.

         if (
             self.block_sizes is not None
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
             # Reshape is required if the logic is not handled in the simulation kernel
             # Only MX format and NVFP4 reshape are currently supported by the kernel.
             self._setup_for_blockquant(inputs)
-            setattr(self, "_original_input_shape", inputs.shape)
+            setattr(self, "_original_input_shape", inputs.shape)
+            cleanup_original_input_shape = True
             inputs = self._process_for_blockquant(inputs)

Add the flag near the top of forward (before this block):

-        # Rotating the input
+        cleanup_original_input_shape = False
+        # Rotating the input
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0af26b2 and 35f3da2.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

701-704: LGTM: using pre‑reshape shape to derive per‑dim blocks.

Using _original_input_shape avoids mismapping after reshape/flatten.


705-714: Pass block_sizes as a keyword; confirm scaled_e4m3 signature and update callers

Change this call to pass block_sizes by name to avoid positional-argument drift; before merging, confirm the scaled_e4m3 definition accepts a named block_sizes parameter (or update all callers if signature changed). Location: modelopt/torch/quantization/nn/modules/tensor_quantizer.py (around lines 705–714).

             outputs = scaled_e4m3(
                 inputs,
                 amax,
                 self._get_bias(inputs),
                 E,
                 M,
                 self._trt_high_precision_dtype,
                 self._pass_through_bwd,
-                block_sizes_list,
+                block_sizes=block_sizes_list,
             )

Quick verification commands to run locally:

  • rg -nP 'def\s+scaled_e4m3\s*\(' -C2
  • rg -nP '\bscaled_e4m3\s*\(' -C2
  • rg -nP '\bscaled_e4m3\s*\([^)]*block_sizes\s*=' -n
